import unittest
from sqlbatis import db, transaction, with_transaction, Engin, connection, with_connection, init_db
from config import MYSQL, PGSQL

from sqlexec.sql_support import get_select_key, get_engin

# DDL used by create_truncate_table() when the ``person`` table does not exist yet.
# NOTE(review): BIGSERIAL is PostgreSQL syntax (MySQL documents only SERIAL);
# DbTestCase initialises the db with the MYSQL config — confirm this DDL is
# accepted there, or supply an engine-specific variant.
create_table_sql = '''
CREATE TABLE person (
  id BIGSERIAL primary key,
  name varchar(45) NOT NULL,
  age int NOT NULL,
  birth_date date DEFAULT NULL,
  sex smallint DEFAULT NULL,
  grade float DEFAULT NULL,
  point decimal(8,2) DEFAULT NULL,
  money decimal(8,4) DEFAULT NULL,
  create_by bigint DEFAULT NULL,
  create_time timestamp DEFAULT CURRENT_TIMESTAMP,
  update_by bigint DEFAULT NULL,
  update_time timestamp DEFAULT NULL,
  del_flag smallint NOT NULL DEFAULT 0
) 
'''


def create_truncate_table(table):
    """Ensure *table* exists and is empty before a test runs.

    If the table already exists it is truncated; otherwise
    ``create_table_sql`` is executed. Note that the create path always runs
    ``create_table_sql`` (which defines ``person``) regardless of *table*.

    :param table: name of the table to look up and truncate. It is
        interpolated directly into the TRUNCATE statement, so it must be a
        trusted identifier (test code only).
    :raises ValueError: if the configured engine is neither PostgreSQL nor MySQL.
    """
    engin = get_engin()  # look the engine up once instead of once per branch
    if engin == Engin.PostgreSQL:
        exists_sql = "select 1 from pg_class where relname =?"
    elif engin == Engin.MySQL:
        exists_sql = "SELECT 1 FROM information_schema.TABLES WHERE table_schema=database() AND table_name=?"
    else:
        # Without this branch the existence result would be unbound and the
        # function would die with an obscure UnboundLocalError.
        raise ValueError(f'Unsupported engine for create_truncate_table: {engin}')

    # presumably do_get returns the first column of the first row (or None
    # when no row matches) — 1 therefore means "table exists".
    cnt = db.do_get(exists_sql, table)
    if cnt == 1:
        db.do_execute(f'truncate table {table}')
    else:
        db.do_execute(create_table_sql)


def drop_table():
    """Remove the ``person`` table; a missing table is not an error (IF EXISTS)."""
    ddl = 'DROP TABLE IF EXISTS person'
    db.execute(ddl)


@with_transaction
def test_transaction(rollback: bool = False):
    """Insert two ``person`` rows inside a ``@with_transaction`` transaction.

    Scenario helper, not a standalone test. It inserts one row and asserts
    the table then holds exactly 4 rows, so it assumes 3 rows already exist.
    It then either raises, so the decorator is expected to roll everything
    back, or saves a second row via ``db.save``.

    :param rollback: when True, raise after the first insert to force a rollback.
    :raises RuntimeError: when *rollback* is True.
    """
    db.insert('person', name='张三', age=55, birth_date='1968-10-08', sex=0, grade=1.0, point=20.5, money=854.56)
    assert db.get('select count(1) from person') == 4, 'transaction'
    if rollback:
        # Deliberately fail mid-transaction so the surrounding transaction is rolled back.
        raise RuntimeError('forced rollback')
    db.save(get_select_key('person_id_seq'), 'person', name='李四', age=55, birth_date='1968-10-08', sex=0, grade=1.0, point=20.5, money=854.56)


# The ``test_`` prefix would make pytest collect and run this helper on its
# own (where the hard-coded row count fails); opt it out of collection.
test_transaction.__test__ = False


def test_transaction2(rollback: bool = False):
    """Insert two ``person`` rows inside an explicit ``with transaction():`` block.

    Scenario helper, not a standalone test. It inserts one row and asserts
    the table then holds exactly 6 rows, so it assumes 5 rows already exist.
    It then either raises inside the ``with`` block, which is expected to
    roll back, or saves a second row via ``db.save``.

    :param rollback: when True, raise after the first insert to force a rollback.
    :raises RuntimeError: when *rollback* is True.
    """
    with transaction():
        db.insert('person', name='张三', age=55, birth_date='1968-10-08', sex=0, grade=1.0, point=20.5, money=854.56)
        assert db.get('select count(1) from person') == 6, 'transaction2'
        if rollback:
            # Deliberately fail inside the transaction block to trigger a rollback.
            raise RuntimeError('forced rollback')
        db.save(get_select_key('person_id_seq'), 'person', name='李四', age=55, birth_date='1968-10-08', sex=0, grade=1.0, point=20.5, money=854.56)


# The ``test_`` prefix would make pytest collect and run this helper on its
# own (where the hard-coded row count fails); opt it out of collection.
test_transaction2.__test__ = False


class DbTestCase(unittest.TestCase):
    """Integration tests for the sqlbatis ``db`` helpers.

    Connects with the settings in ``config.MYSQL`` and recreates/truncates the
    ``person`` table before every test.
    """

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        # Initialise the connection once per test class, not as an
        # import-time side effect of executing the class body.
        init_db(**MYSQL)

    def setUp(self) -> None:
        # Start every test from an empty ``person`` table.
        create_truncate_table("person")

    def test_insert(self):
        rowcount = db.insert('person', name='张三', age=55, birth_date='1968-10-08', sex=0, grade=1.0, point=20.5, money=854.56)
        self.assertEqual(rowcount, 1, 'insert')
        self.assertEqual(db.get('select count(1) from person'), 1)

    @with_connection
    def test_save(self):
        id2 = db.save(get_select_key('person_id_seq'), 'person', name='李四', age=55, birth_date='1968-10-08', sex=0, grade=1.0, point=20.5, money=854.56)
        self.assertGreater(id2, 0, 'save')
        self.assertEqual(db.get('select count(1) from person'), 1)
        db.execute('update person set name=? where id=?', '王五', id2)
        self.assertEqual(db.get('select name from person where id=?', id2), '王五')
        # TODO: the scenarios below are disabled; re-enable them (adjusting the
        # expected row counts to this test's fixture) or delete them.
        # db.execute('update person set name = :name where id = :id', name='赵六', id=id2)
        # self.assertEqual(db.select_one('select id, name from person where id=:id', id=id2)[0], id2)
        # db.execute('update person set name = :name where id = :id', name='赵六', id=id2)
        # self.assertEqual(db.query_one('select name from person where id=:id', id=id2)['name'], '赵六')

        # args = [
        #     ('张三', 55, '1968-10-08', 0, 1.0, 20.5, 854.56),
        #     ('张三', 55, '1968-10-08', 0, 1.0, 20.5, 854.56)
        # ]
        # db.batch_execute('insert into person(name, age, birth_date, sex, grade, point, money) values(?,?,?,?,?,?,?)', *args)
        # persons = db.select('select id, del_flag from person')
        # assert len(persons) == 4, f'batch_execute: {len(persons)}'
        # persons = db.query('select id, del_flag from person')
        # assert len(persons) == 4, 'batch_execute'
        #
        # persons = db.select('select id, del_flag from person where id=?', id2)
        # assert len(persons) == 1, 'select'
        # persons = db.query('select id, del_flag from person where id=?', id2)
        # assert len(persons) == 1, 'select'
        #
        # persons = db.select('select id, del_flag from person where id=:id', id=id2)
        # assert len(persons) == 1, 'select'
        # persons = db.query('select id, del_flag from person where id=:id', id=id2)
        # assert len(persons) == 1, 'select'
        #
        # db.execute('delete from person where id=?', id2)
        # assert db.get('select count(1) from person') == 3, 'execute delete'
        #
        # try:
        #     test_transaction(rollback=True)
        # except Exception:
        #     print('Rollback.')
        # assert db.get('select count(1) from person') == 3, 'transaction'
        #
        # test_transaction(rollback=False)
        # assert db.get('select count(1) from person') == 5, 'transaction'
        #
        # try:
        #     test_transaction2(rollback=True)
        # except Exception:
        #     print('Rollback.')
        # assert db.get('select count(1) from person') == 5, 'transaction2'
        #
        # test_transaction2(rollback=False)
        # assert db.get('select count(1) from person') == 7, 'transaction2'


if __name__ == '__main__':
    # Discover and run the TestCase classes of this module via the unittest CLI.
    unittest.main(module='__main__')
